import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from db_utils import *

def matplotlib3():
    """Show a 3D bar chart of each class's average score per subject.

    Data comes from ``get_class_averages_all_subjects()``; every item is
    read via the keys ``'class'``, ``'chinese_avg'``, ``'math_avg'`` and
    ``'english_avg'``. Classes are laid out along the x axis, subjects
    along the y axis, and the bar height is the average score.

    When the data source returns nothing (falsy), a message is printed and
    no figure is created. Otherwise the figure is displayed with
    ``plt.show()``. Nothing is returned.
    """
    data = get_class_averages_all_subjects()
    if not data:
        print("没有数据可以显示.")
        return

    # Configure a CJK-capable font before any text artist is created, so the
    # Chinese labels do not depend on lazy font resolution at draw time.
    # Note: this mutates matplotlib's global rcParams.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    classes = [item['class'] for item in data]
    subjects = ["语文", "数学", "英语"]
    score_keys = ['chinese_avg', 'math_avg', 'english_avg']
    colors = ['#FFB6C1', '#98FB98', '#87CEEB']

    # Height matrix: one row per class, one column per subject. Coerce to
    # float so the plot always receives a numeric array -- presumably the DB
    # layer may return non-float numerics such as Decimal; TODO confirm.
    heights = np.array(
        [[item[key] for key in score_keys] for item in data],
        dtype=float,
    )

    x_pos = np.arange(len(classes))
    y_pos = np.arange(len(subjects))
    dx = dy = 0.5

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    # One bar3d call per subject so each subject row gets its own colour.
    for subject_idx, color in enumerate(colors):
        ax.bar3d(
            x_pos,
            np.full_like(x_pos, y_pos[subject_idx]),
            np.zeros_like(x_pos),
            dx,
            dy,
            heights[:, subject_idx],
            color=color,
            shade=True,
        )

    # bar3d anchors each bar at its (x, y) corner and extends it by dx/dy,
    # so shift the ticks by half a bar width to centre labels under the bars.
    ax.set_xticks(x_pos + dx / 2)
    ax.set_xticklabels(classes)
    ax.set_yticks(y_pos + dy / 2)
    ax.set_yticklabels(subjects)
    ax.set_zlabel('平均分')

    # NOTE(review): the title opens with a full-width parenthesis but closes
    # with an ASCII one; kept byte-for-byte here -- confirm before changing.
    ax.set_title('各班级在不同科目中的平均分（三维柱状图)')
    plt.tight_layout()
    plt.show()